import argparse

import numpy as np
import torch as th
import torch.optim as optim
from dgl.data import PPIDataset
from dgl.dataloading import GraphDataLoader
from sklearn.metrics import f1_score

from model import GeniePath, GeniePathLazy


def evaluate(model, loss_fn, dataloader, device='cpu'):
    """Run one full pass over ``dataloader`` and return batch-averaged metrics.

    Parameters
    ----------
    model : callable
        Model invoked as ``model(subgraph, feat)`` returning per-node logits.
    loss_fn : callable
        Loss taking ``(logits, label)``; expected to work on raw logits
        (e.g. ``BCEWithLogitsLoss``).
    dataloader : iterable
        Yields DGL subgraphs carrying ``ndata['feat']`` and ``ndata['label']``.
    device : str
        Device the subgraphs/labels are moved to before the forward pass.

    Returns
    -------
    (float, float)
        ``(micro_f1, loss)``, each averaged over the number of batches.

    Notes
    -----
    Evaluation needs no gradients, so the loop runs under ``th.no_grad()``
    to avoid building autograd state for every batch. The caller is
    expected to have switched ``model`` to eval mode already.
    """
    loss = 0
    f1 = 0
    num_blocks = 0
    with th.no_grad():
        for subgraph in dataloader:
            subgraph = subgraph.to(device)
            label = subgraph.ndata['label'].to(device)
            feat = subgraph.ndata['feat']
            logits = model(subgraph, feat)

            # Accumulate loss; with a logits-based loss, logit >= 0 is
            # equivalent to predicted probability >= 0.5, hence the
            # threshold below for the multi-label prediction.
            loss += loss_fn(logits, label).item()
            predict = np.where(logits.data.cpu().numpy() >= 0., 1, 0)
            f1 += f1_score(label.cpu(), predict, average='micro')
            num_blocks += 1

    return f1 / num_blocks, loss / num_blocks


def main(args):
    """Train GeniePath (or the Lazy variant) on PPI, then report test metrics."""
    # Step 1: Prepare graph data and dataloaders ========================== #
    # Keep a handle on the training split: it provides the label count and
    # the input feature dimensionality used to size the model.
    train_dataset = PPIDataset(mode='train')
    loaders = {
        'train': GraphDataLoader(train_dataset, batch_size=args.batch_size),
        'valid': GraphDataLoader(PPIDataset(mode='valid'), batch_size=args.batch_size),
        'test': GraphDataLoader(PPIDataset(mode='test'), batch_size=args.batch_size),
    }

    # Select the compute device: requested GPU if available, else CPU.
    use_cuda = args.gpu >= 0 and th.cuda.is_available()
    device = 'cuda:{}'.format(args.gpu) if use_cuda else 'cpu'

    num_classes = train_dataset.num_labels
    in_dim = train_dataset[0].ndata['feat'].shape[-1]

    # Step 2: Create model ================================================ #
    # Both variants share the same constructor signature, so dispatch on
    # the class object instead of duplicating the keyword list.
    model_cls = GeniePathLazy if args.lazy else GeniePath
    model = model_cls(in_dim=in_dim,
                      out_dim=num_classes,
                      hid_dim=args.hid_dim,
                      num_layers=args.num_layers,
                      num_heads=args.num_heads,
                      residual=args.residual).to(device)

    # Step 3: Create training components ================================== #
    loss_fn = th.nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Step 4: Training epochs ============================================= #
    for epoch in range(args.max_epoch):
        model.train()
        epoch_loss = 0
        epoch_f1 = 0
        batches = 0
        for subgraph in loaders['train']:
            subgraph = subgraph.to(device)
            labels = subgraph.ndata['label']
            logits = model(subgraph, subgraph.ndata['feat'])

            # Track loss and micro-F1 per batch; logit >= 0 serves as the
            # 0.5-probability decision threshold for the multi-label task.
            batch_loss = loss_fn(logits, labels)
            epoch_loss += batch_loss.item()
            preds = np.where(logits.data.cpu().numpy() >= 0., 1, 0)
            epoch_f1 += f1_score(labels.cpu(), preds, average='micro')
            batches += 1

            # Standard optimization step.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        # Validation after each epoch.
        model.eval()
        val_f1, val_loss = evaluate(model, loss_fn, loaders['valid'], device)

        print("In epoch {}, Train F1: {:.4f} | Train Loss: {:.4f}; Valid F1: {:.4f} | Valid loss: {:.4f}".
              format(epoch, epoch_f1 / batches, epoch_loss / batches, val_f1, val_loss))

    # Final evaluation on the held-out test split.
    model.eval()
    test_f1, test_loss = evaluate(model, loss_fn, loaders['test'], device)

    print("Test F1: {:.4f} | Test loss: {:.4f}".
          format(test_f1, test_loss))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GeniePath')
    parser.add_argument("--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU.")
    parser.add_argument("--hid_dim", type=int, default=256, help="Hidden layer dimension")
    parser.add_argument("--num_layers", type=int, default=3, help="Number of GeniePath layers")
    parser.add_argument("--max_epoch", type=int, default=1000, help="The max number of epochs. Default: 1000")
    parser.add_argument("--lr", type=float, default=0.0004, help="Learning rate. Default: 0.0004")
    parser.add_argument("--num_heads", type=int, default=1, help="Number of head in breadth function. Default: 1")
    # NOTE: these two flags previously used ``type=bool``, which is an
    # argparse pitfall: bool() of any non-empty string (including "False")
    # is True, so e.g. ``--lazy False`` still enabled the flag.
    # ``action='store_true'`` gives the intended off-by-default toggle
    # while keeping the same default (False) seen by main().
    parser.add_argument("--residual", action='store_true', help="Residual in GAT or not")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size of graph dataloader")
    parser.add_argument("--lazy", action='store_true', help="Variant GeniePath-Lazy")

    args = parser.parse_args()
    print(args)
    # Fixed seed so repeated runs are comparable.
    th.manual_seed(16)
    main(args)